/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <proxygen/lib/http/codec/HTTPCodec.h>
#include <proxygen/lib/http/webtransport/FlowController.h>
#include <proxygen/lib/http/webtransport/WebTransport.h>
#include <quic/QuicConstants.h>

// Default WebTransport flow-control tuning constants.
//
// These are `inline constexpr` (C++17) rather than wrapped in an unnamed
// namespace: an unnamed namespace in a header gives every including
// translation unit its own private copy, whereas inline variables have a
// single shared definition. They stay at global scope so that existing
// unqualified and `::`-qualified references keep resolving.

// Default receive window for a WebTransport session (presumably bytes, as for
// WT_MAX_DATA -- confirm against the .cpp that consumes it).
inline constexpr uint64_t kDefaultWTReceiveWindow = 1'048'576;
// Default for StreamFlowControl::targetConcurrentStreams (see WebTransportImpl).
inline constexpr uint64_t kDefaultTargetConcurrentStreams = 50;
// Fraction used when deciding whether to grant more credit; not referenced in
// this header -- presumably consumed by shouldGrantFlowControl() in the .cpp.
inline constexpr double kFlowControlThreshold = 0.5;

namespace proxygen {

/**
 * WebTransport session implementation layered on top of an HTTP transport.
 *
 * Stream creation, stream I/O, datagrams and flow-control frames are
 * delegated to a TransportProvider; session closure is delegated to a
 * SessionProvider. Per-stream state lives in StreamWriteHandle /
 * StreamReadHandle objects held in wtEgressStreams_ / wtIngressStreams_.
 * Each handle stores a reference back to this object, so the object is
 * non-copyable.
 */
class WebTransportImpl : public WebTransport {
 public:
  /**
   * Transport-side operations that WebTransportImpl relies on.
   */
  class TransportProvider {
   public:
    virtual ~TransportProvider() = default;

    // Asks the transport to open a new bidirectional stream; returns its ID
    // or an error.
    virtual folly::Expected<HTTPCodec::StreamID, WebTransport::ErrorCode>
    newWebTransportBidiStream() = 0;

    // Asks the transport to open a new unidirectional stream; returns its ID
    // or an error.
    virtual folly::Expected<HTTPCodec::StreamID, WebTransport::ErrorCode>
    newWebTransportUniStream() = 0;

    // Presumably complete once another uni / bidi stream may be opened
    // (forwarded as-is by WebTransportImpl::awaitUni/BidiStreamCredit).
    virtual folly::SemiFuture<folly::Unit> awaitUniStreamCredit() = 0;

    virtual folly::SemiFuture<folly::Unit> awaitBidiStreamCredit() = 0;

    virtual bool canCreateUniStream() = 0;

    virtual bool canCreateBidiStream() = 0;

    // Writes `data` (with FIN when `eof`) on stream `id`. The optional
    // byteEventCallback is handed through to the transport.
    virtual folly::Expected<FCState, WebTransport::ErrorCode>
    sendWebTransportStreamData(HTTPCodec::StreamID /*id*/,
                               std::unique_ptr<folly::IOBuf> /*data*/,
                               bool /*eof*/,
                               ByteEventCallback* /* byteEventCallback */) = 0;

    // Sends a session-level max-data update carrying `maxData`.
    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode> sendWTMaxData(
        uint64_t maxData) = 0;

    // Sends a max-streams update for bidi (`isBidi`) or uni streams.
    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
    sendWTMaxStreams(uint64_t maxStreams, bool isBidi) = 0;

    // Registers `wcb` to be notified when stream `id` becomes writable.
    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
    notifyPendingWriteOnStream(HTTPCodec::StreamID,
                               quic::StreamWriteCallback* wcb) = 0;

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
    resetWebTransportEgress(HTTPCodec::StreamID /*id*/,
                            uint32_t /*errorCode*/) = 0;

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
        setWebTransportStreamPriority(HTTPCodec::StreamID /*id*/,
                                      HTTPPriority /*pri*/) = 0;

    // Reads up to `max` bytes from stream `id`; the bool in the result is
    // presumably the EOF flag.
    virtual folly::Expected<std::pair<std::unique_ptr<folly::IOBuf>, bool>,
                            WebTransport::ErrorCode>
    readWebTransportData(HTTPCodec::StreamID /*id*/, size_t /*max*/) = 0;

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
    initiateReadOnBidiStream(HTTPCodec::StreamID /*id*/,
                             quic::StreamReadCallback* /*readCallback*/) = 0;

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
        pauseWebTransportIngress(HTTPCodec::StreamID /*id*/) = 0;

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
        resumeWebTransportIngress(HTTPCodec::StreamID /*id*/) = 0;

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode>
        stopReadingWebTransportIngress(
            HTTPCodec::StreamID /*id*/,
            folly::Optional<uint32_t> /*errorCode*/) = 0;
    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode> sendDatagram(
        std::unique_ptr<folly::IOBuf> /*datagram*/) = 0;

    virtual const folly::SocketAddress& getLocalAddress() const = 0;
    virtual const folly::SocketAddress& getPeerAddress() const = 0;

    // Default: an empty TransportInfo.
    [[nodiscard]] virtual quic::TransportInfo getTransportInfo() const {
      return {};
    }

    virtual bool usesEncodedApplicationErrorCodes() = 0;

    // Initial session send window; used to seed sendFlowController_.
    // Default: quic::kMaxVarInt, i.e. effectively unlimited.
    virtual uint64_t getWTInitialSendWindow() const {
      return quic::kMaxVarInt;
    }
  };

  /**
   * Session-side operations that WebTransportImpl relies on.
   */
  class SessionProvider {
   public:
    virtual ~SessionProvider() = default;

    // Default: no-op.
    virtual void refreshTimeout() {
    }

    virtual folly::Expected<folly::Unit, WebTransport::ErrorCode> closeSession(
        folly::Optional<uint32_t> /*error*/) = 0;
  };

  /**
   * @param tp  Transport hooks; must outlive this object (held by reference).
   * @param sp  Session hooks; must outlive this object (held by reference).
   *
   * sendFlowController_ is seeded from tp.getWTInitialSendWindow(). This is
   * only valid because tp_ is declared before sendFlowController_ and is
   * therefore initialized first.
   */
  WebTransportImpl(TransportProvider& tp, SessionProvider& sp)
      : tp_(tp), sp_(sp), sendFlowController_(tp_.getWTInitialSendWindow()) {
  }

  // Stream handles hold a WebTransportImpl& back-reference, so a copy would
  // leave them pointing at the original instance.
  WebTransportImpl(const WebTransportImpl&) = delete;
  WebTransportImpl& operator=(const WebTransportImpl&) = delete;

  ~WebTransportImpl() override = default;

  void destroy();

  // WT API
  folly::Expected<WebTransport::StreamWriteHandle*, WebTransport::ErrorCode>
  createUniStream() override {
    return newWebTransportUniStream();
  }

  folly::Expected<WebTransport::BidiStreamHandle, WebTransport::ErrorCode>
  createBidiStream() override {
    return newWebTransportBidiStream();
  }
  folly::SemiFuture<folly::Unit> awaitUniStreamCredit() override {
    return tp_.awaitUniStreamCredit();
  }
  folly::SemiFuture<folly::Unit> awaitBidiStreamCredit() override {
    return tp_.awaitBidiStreamCredit();
  }
  // The per-stream calls below fail with INVALID_STREAM_ID when `id` has no
  // handle in the relevant (ingress or egress) map.
  folly::Expected<folly::SemiFuture<StreamData>, WebTransport::ErrorCode>
  readStreamData(uint64_t id) override {
    auto it = wtIngressStreams_.find(id);
    if (it == wtIngressStreams_.end()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::INVALID_STREAM_ID);
    }
    return it->second.readStreamData();
  }
  folly::Expected<folly::Unit, WebTransport::ErrorCode> stopSending(
      uint64_t id, uint32_t error) override {
    auto it = wtIngressStreams_.find(id);
    if (it == wtIngressStreams_.end()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::INVALID_STREAM_ID);
    }
    return it->second.stopSending(error);
  }
  folly::Expected<FCState, ErrorCode> writeStreamData(
      uint64_t id,
      std::unique_ptr<folly::IOBuf> data,
      bool fin,
      ByteEventCallback* deliveryCallback) override {
    auto it = wtEgressStreams_.find(id);
    if (it == wtEgressStreams_.end()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::INVALID_STREAM_ID);
    }
    return it->second.writeStreamData(std::move(data), fin, deliveryCallback);
  }
  folly::Expected<folly::Unit, WebTransport::ErrorCode> resetStream(
      uint64_t id, uint32_t error) override {
    auto it = wtEgressStreams_.find(id);
    if (it == wtEgressStreams_.end()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::INVALID_STREAM_ID);
    }
    return it->second.resetStream(error);
  }
  folly::Expected<folly::Unit, WebTransport::ErrorCode> setPriority(
      uint64_t id, uint8_t level, uint32_t order, bool incremental) override {
    auto it = wtEgressStreams_.find(id);
    if (it == wtEgressStreams_.end()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::INVALID_STREAM_ID);
    }
    return it->second.setPriority(level, order, incremental);
  }
  folly::Expected<folly::SemiFuture<uint64_t>, ErrorCode> awaitWritable(
      uint64_t streamId) override {
    auto it = wtEgressStreams_.find(streamId);
    if (it == wtEgressStreams_.end()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::INVALID_STREAM_ID);
    }
    return it->second.awaitWritable();
  }
  // Fails with SESSION_TERMINATED once terminateSession() has run. Any
  // transport-level failure is reported to the caller as SEND_ERROR.
  folly::Expected<folly::Unit, WebTransport::ErrorCode> sendDatagram(
      std::unique_ptr<folly::IOBuf> datagram) override {
    if (sessionCloseError_.has_value()) {
      return folly::makeUnexpected(WebTransport::ErrorCode::SESSION_TERMINATED);
    }
    // This can bypass the size and state machine checks in
    // HTTPTransaction::sendDatagram
    if (!tp_.sendDatagram(std::move(datagram))) {
      return folly::makeUnexpected(WebTransport::ErrorCode::SEND_ERROR);
    }
    return folly::unit;
  }

  const folly::SocketAddress& getLocalAddress() const override {
    return tp_.getLocalAddress();
  }

  const folly::SocketAddress& getPeerAddress() const override {
    return tp_.getPeerAddress();
  }

  [[nodiscard]] quic::TransportInfo getTransportInfo() const override {
    return tp_.getTransportInfo();
  }

  folly::Expected<folly::Unit, WebTransport::ErrorCode> closeSession(
      folly::Optional<uint32_t> error) override {
    return sp_.closeSession(error);
  }

  /**
   * Egress half of a WebTransport stream. Writes that do not fit within the
   * available flow control are queued in bufferedWrites_.
   */
  class StreamWriteHandle
      : public WebTransport::StreamWriteHandle
      , public quic::StreamWriteCallback {
   public:
    StreamWriteHandle(WebTransportImpl& impl, HTTPCodec::StreamID id);

    ~StreamWriteHandle() override {
      cs_.requestCancellation();
    }

    folly::Expected<FCState, WebTransport::ErrorCode> writeStreamData(
        std::unique_ptr<folly::IOBuf> data,
        bool fin,
        ByteEventCallback* deliveryCallback) override;

    folly::Expected<folly::Unit, WebTransport::ErrorCode> resetStream(
        uint32_t errorCode) override {
      return impl_.resetWebTransportEgress(id_, errorCode);
    }

    folly::Expected<folly::Unit, WebTransport::ErrorCode> setPriority(
        uint8_t level, uint32_t order, bool incremental) override {
      return impl_.tp_.setWebTransportStreamPriority(
          getID(), {level, incremental, order});
    }
    folly::Expected<folly::SemiFuture<uint64_t>, ErrorCode> awaitWritable()
        override;

    void onStopSending(uint32_t errorCode);
    folly::Expected<WebTransport::FCState, WebTransport::ErrorCode>
    flushBufferedWrites();
    void maybeResolveWritePromise(uint64_t maxToSend);

    // TODO: what happens to promise_ if this stream is reset or the
    // conn closes

   private:
    /*
     * Used to store buffered writes when we have insufficient flow control.
     */
    struct BufferedWrite {
      folly::IOBufQueue buf;
      ByteEventCallback* deliveryCallback;
      bool fin;

      BufferedWrite(std::unique_ptr<folly::IOBuf> data,
                    ByteEventCallback* callback,
                    bool fin)
          : buf(folly::IOBufQueue::cacheChainLength()),
            deliveryCallback(callback),
            fin(fin) {
        // A null `data` is allowed (e.g. a bare FIN); it is simply not
        // appended.
        if (data) {
          buf.append(std::move(data));
        }
      }
    };

    void onStreamWriteReady(quic::StreamId id, uint64_t) noexcept override;

    WebTransportImpl& impl_;
    folly::Promise<uint64_t> writePromise_;
    std::list<BufferedWrite> bufferedWrites_;
    bool streamWriteReady_{false};
  };

  /**
   * Ingress half of a WebTransport stream. Received data accumulates in buf_
   * until it is consumed through readStreamData().
   */
  class StreamReadHandle
      : public WebTransport::StreamReadHandle
      , public quic::StreamReadCallback {
   public:
    StreamReadHandle(WebTransportImpl& impl, HTTPCodec::StreamID id);

    ~StreamReadHandle() override {
      cs_.requestCancellation();
    }

    folly::SemiFuture<WebTransport::StreamData> readStreamData() override;

    folly::Expected<folly::Unit, WebTransport::ErrorCode> stopSending(
        uint32_t error) override {
      return impl_.stopReadingWebTransportIngress(id_, error);
    }

    FCState dataAvailable(std::unique_ptr<folly::IOBuf> data, bool eof);
    void deliverReadError(const folly::exception_wrapper& ex);
    // True until either EOF has been seen or an error has been stored.
    [[nodiscard]] bool open() const {
      return !eof_ && !ex_;
    }

    // quic::StreamReadCallback overrides
    void readAvailable(quic::StreamId id) noexcept override;
    void readError(quic::StreamId id, quic::QuicError error) noexcept override;

   private:
    WebTransportImpl& impl_;
    folly::Promise<WebTransport::StreamData> readPromise_;
    folly::IOBufQueue buf_{folly::IOBufQueue::cacheChainLength()};
    bool eof_{false};
  };

 private:
  folly::Expected<WebTransport::FCState, WebTransport::ErrorCode>
  sendWebTransportStreamData(HTTPCodec::StreamID id,
                             std::unique_ptr<folly::IOBuf> data,
                             bool eof,
                             ByteEventCallback* deliveryCallback);

  folly::Expected<folly::Unit, WebTransport::ErrorCode> resetWebTransportEgress(
      HTTPCodec::StreamID id, uint32_t errorCode);

  folly::Expected<folly::Unit, WebTransport::ErrorCode>
  stopReadingWebTransportIngress(HTTPCodec::StreamID id,
                                 folly::Optional<uint32_t> errorCode);

  folly::Expected<WebTransport::BidiStreamHandle, WebTransport::ErrorCode>
  newWebTransportBidiStream();

  folly::Expected<WebTransport::StreamWriteHandle*, WebTransport::ErrorCode>
  newWebTransportUniStream();

  // NOTE: tp_ must stay declared before sendFlowController_, because the
  // constructor initializes the latter from tp_.
  TransportProvider& tp_;
  SessionProvider& sp_;
  using WTEgressStreamMap = std::map<HTTPCodec::StreamID, StreamWriteHandle>;
  using WTIngressStreamMap = std::map<HTTPCodec::StreamID, StreamReadHandle>;
  WTEgressStreamMap wtEgressStreams_;
  WTIngressStreamMap wtIngressStreams_;
  // Both controllers default to quic::kMaxVarInt (no effective limit) so that
  // flow control is a no-op unless configured. The constructor always
  // overrides sendFlowController_ with tp_.getWTInitialSendWindow(), which
  // itself defaults to kMaxVarInt. setFlowControlLimits() may replace either
  // controller later.
  FlowController sendFlowController_{quic::kMaxVarInt};
  FlowController recvFlowController_{quic::kMaxVarInt};
  // Set by terminateSession(); once present the session is considered gone.
  folly::Optional<uint32_t> sessionCloseError_;
  uint64_t bytesRead_{};

  // Per-direction-type stream credit bookkeeping.
  struct StreamFlowControl {
    uint64_t maxStreamID{quic::kMaxVarInt};
    uint64_t numClosedStreams{};
    uint64_t targetConcurrentStreams{kDefaultTargetConcurrentStreams};
  };

  StreamFlowControl uniStreamFlowControl_;
  StreamFlowControl bidiStreamFlowControl_;

 public:
  [[nodiscard]] bool isSessionTerminated() const {
    return sessionCloseError_.has_value();
  }

  [[nodiscard]] folly::Optional<uint32_t> getSessionCloseError() const {
    return sessionCloseError_;
  }

  /**
   * Marks the session as terminated and tears down its streams with
   * WebTransport::kSessionGone.
   *
   * @param error  Error code to record; 0 is recorded when none is supplied.
   */
  void terminateSession(const folly::Optional<uint32_t>& error = folly::none) {
    sessionCloseError_ = error.value_or(0);
    terminateSessionStreams(WebTransport::kSessionGone, "session terminated");
  }

 protected:
  void terminateSessionStreams(uint32_t errorCode, const std::string& reason);

 public:
  // Calls from Transport
  struct BidiStreamHandle {
    StreamReadHandle* readHandle;
    StreamWriteHandle* writeHandle;
  };
  BidiStreamHandle onWebTransportBidiStream(HTTPCodec::StreamID id);

  WebTransportImpl::StreamReadHandle* onWebTransportUniStream(
      HTTPCodec::StreamID id);

  void onWebTransportStopSending(HTTPCodec::StreamID id, uint32_t errorCode);

  void onMaxData(uint64_t maxData) noexcept;
  void onMaxStreams(uint64_t maxStreams, bool isBidi) noexcept;

  void maybeGrantFlowControl(uint64_t bytesRead);
  void maybeGrantStreamCredit(HTTPCodec::StreamID id,
                              bool closingReadHandle,
                              bool closingWriteHandle);
  [[nodiscard]] bool shouldGrantFlowControl() const;
  [[nodiscard]] bool shouldGrantStreamCredit(bool isBidi) const;

  void closeEgressStream(HTTPCodec::StreamID id);
  void closeIngressStream(HTTPCodec::StreamID id);

  // Bit 0x2 of a stream ID is clear for bidirectional streams and set for
  // unidirectional ones (the QUIC stream-ID layout).
  bool isBidirectional(HTTPCodec::StreamID id) const {
    return (id & 0b10) == 0;
  }

  // Replaces both session-level flow controllers; any previously tracked
  // state in them is discarded.
  void setFlowControlLimits(uint64_t sendWindow = quic::kMaxVarInt,
                            uint64_t recvWindow = quic::kMaxVarInt) {
    sendFlowController_ = FlowController(sendWindow);
    recvFlowController_ = FlowController(recvWindow);
  }

  void setUniStreamFlowControl(uint64_t maxStreamId,
                               uint64_t targetConcurrentStreams) {
    uniStreamFlowControl_.maxStreamID = maxStreamId;
    uniStreamFlowControl_.targetConcurrentStreams = targetConcurrentStreams;
  }

  void setBidiStreamFlowControl(uint64_t maxStreamId,
                                uint64_t targetConcurrentStreams) {
    bidiStreamFlowControl_.maxStreamID = maxStreamId;
    bidiStreamFlowControl_.targetConcurrentStreams = targetConcurrentStreams;
  }
};

} // namespace proxygen
